# Author    :Dun_Hz
# Time      :2024/3/13 16:23

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
from PIL import Image


class WaveEquationFD:
    """Explicit finite-difference solver for the 2-D wave equation.

    The update formulas below are the central-difference discretisation of
    ``u_tt = D * (u_xx + u_yy)`` on the rectangle
    ``[xmin, xmax] x [ymin, ymax]`` (fixed to ``[0, 2] x [0, 2]``) for
    ``0 <= t <= tend`` (fixed to 6).

    Grid convention used throughout: ``self.u[i, j]`` is the solution at
    ``(self.x[i], self.y[j])``, so every solution array has shape
    ``(Mx + 1, My + 1)``.

    Args:
        N: Number of time steps; the step size is ``dt = tend / N``.
        D: Coefficient multiplying the Laplacian.
        Mx: Number of grid intervals along x (``Mx + 1`` nodes).
        My: Number of grid intervals along y (``My + 1`` nodes).

    Raises:
        AssertionError: If the stability number computed in
            ``initialization`` is not below 1.
    """

    def __init__(self, N, D, Mx, My):
        self.N = N
        self.D = D
        self.Mx = Mx
        self.My = My
        self.tend = 6
        self.xmin = 0
        self.xmax = 2
        self.ymin = 0
        self.ymax = 2
        self.initialization()
        self.eqnApprox()

    def initialization(self):
        """Build the space/time grids and the problem data callables."""
        self.dx = (self.xmax - self.xmin) / self.Mx
        self.dy = (self.ymax - self.ymin) / self.My

        # linspace yields exactly Mx + 1 / My + 1 nodes; arange with a float
        # step can overshoot by one element because of rounding.
        self.x = np.linspace(self.xmin, self.xmax, self.Mx + 1)
        self.y = np.linspace(self.ymin, self.ymax, self.My + 1)

        # ----- Initial condition: Gaussian bump centred in the domain -----#
        self.u0 = lambda r, s: 0.2 * np.exp(
            -((r - (self.xmax + self.xmin) / 2) ** 2 + (s - (self.ymax + self.ymin) / 2) ** 2) / 0.01)
        # ----- Initial velocity -----#
        self.v0 = lambda a, b: 0

        # ----- Boundary conditions (value on the edge at a given time) -----#
        self.bxyt = lambda left, right, time: 0

        self.dt = (self.tend - 0) / self.N
        # The dt / 2 margin keeps the end point in the range despite rounding.
        self.t = np.arange(0, self.tend + self.dt / 2, self.dt)

        # Stability check. For dx == dy this equals rx + ry < 1.
        # NOTE(review): for dx != dy the usual explicit-scheme limit is
        # rx + ry <= 1, which is stricter than this expression - confirm
        # before relying on it with anisotropic grids. Being an assert, the
        # check is skipped when Python runs with -O.
        r = 4 * self.D * self.dt ** 2 / (self.dx ** 2 + self.dy ** 2)
        assert r < 1, "r is bigger than 1!"

    def eqnApprox(self):
        """Compute the scheme coefficients and allocate/fill the solution arrays."""
        # ----- Approximation equation properties -----#
        self.rx = self.D * self.dt ** 2 / self.dx ** 2
        self.ry = self.D * self.dt ** 2 / self.dy ** 2
        self.rxy1 = 1 - self.rx - self.ry
        self.rxy2 = self.rxy1 * 2

        # ----- Initialization matrix u for solution -----#
        self.u = np.zeros((self.Mx + 1, self.My + 1))
        self.ut = np.zeros((self.Mx + 1, self.My + 1))

        # ----- Fills initial condition and initial velocity -----#
        # Interior nodes only; i runs along x (axis 0), j along y (axis 1).
        for i in range(1, self.Mx):
            for j in range(1, self.My):
                self.u[i, j] = self.u0(self.x[i], self.y[j])
                self.ut[i, j] = self.v0(self.x[i], self.y[j])

        # Snapshot of the current time level. It has to be taken AFTER the
        # fill: the first step reads its neighbour values from this array.
        self.u_1 = self.u.copy()

    def _apply_boundary(self, t):
        """Write the boundary values for time ``t`` onto the four edges of ``self.u``."""
        # Edges x = xmin and x = xmax.
        for j in range(self.My + 1):
            self.u[0, j] = self.bxyt(self.x[0], self.y[j], t)
            self.u[self.Mx, j] = self.bxyt(self.x[self.Mx], self.y[j], t)

        # Edges y = ymin and y = ymax.
        for i in range(self.Mx + 1):
            self.u[i, 0] = self.bxyt(self.x[i], self.y[0], t)
            self.u[i, self.My] = self.bxyt(self.x[i], self.y[self.My], t)

    def _advance(self, u_2, first_step):
        """Advance the interior of ``self.u`` by one time step, in place.

        Args:
            u_2: Solution two levels back (level n-1); ignored on the first step.
            first_step: True for the start-up step, which uses the initial
                velocity ``self.ut`` instead of ``u_2``.

        Returns:
            The array holding level n, to be passed as ``u_2`` on the next call.
        """
        inner = (slice(1, -1), slice(1, -1))
        prev = self.u_1
        # Sums of the two neighbours along each axis, taken from level n.
        x_pair = prev[:-2, 1:-1] + prev[2:, 1:-1]
        y_pair = prev[1:-1, :-2] + prev[1:-1, 2:]
        centre = self.u[inner]

        # The right-hand side is fully evaluated into a new array before the
        # assignment, so updating self.u in place is safe.
        if first_step:
            self.u[inner] = 0.5 * (self.rx * x_pair) \
                            + 0.5 * (self.ry * y_pair) \
                            + self.rxy1 * centre + self.dt * self.ut[inner]
        else:
            self.u[inner] = self.rx * x_pair \
                            + self.ry * y_pair \
                            + self.rxy2 * centre - u_2[inner]

        # self.u_1 is rebound to a fresh copy, so the old array can be handed
        # back as level n-1 without copying it.
        self.u_1 = self.u.copy()
        return prev

    def solve_and_animate(self):
        """Run all ``N`` time steps, redrawing and saving a surface plot after each.

        Side effects: opens a matplotlib figure, overwrites ``self.t`` with the
        time passed to the boundary function at each step, and saves one image
        per step through ``plt.savefig`` under the names ``1``, ``2``, ...
        """
        u_2 = np.zeros((self.Mx + 1, self.My + 1))

        # 'ij' indexing gives (Mx + 1, My + 1) coordinate arrays, matching self.u.
        xx, yy = np.meshgrid(self.x, self.y, indexing='ij')

        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')

        wframe = None
        count = 1
        k = 0

        # k advances by whole steps so that k * dt is the simulated time and
        # exactly N steps are taken.
        while k < self.N:
            if wframe is not None:
                # Artist.remove() is the supported way to drop the previous
                # surface; Axes.collections is not a mutable list any more.
                wframe.remove()

            self.t = k * self.dt

            self._apply_boundary(self.t)
            u_2 = self._advance(u_2, first_step=(k == 0))

            wframe = ax.plot_surface(xx, yy, self.u, cmap='rainbow', linewidth=2,
                                     antialiased=False)

            ax.set_xlim3d(0, 2.0)
            ax.set_ylim3d(0, 2.0)
            ax.set_zlim3d(-1.5, 1.5)

            ax.set_xticks([0, 0.5, 1.0, 1.5, 2.0])
            ax.set_yticks([0, 0.5, 1.0, 1.5, 2.0])

            ax.set_xlabel("x")
            ax.set_ylabel("y")
            ax.set_zlabel("U")
            plt.savefig(str(count))
            plt.pause(0.01)
            k += 1
            count += 1


def main():
    """Run the demo: 200 time steps, D = 0.1, on a 100 x 100 interval grid.

    After the animation loop finishes, the figure window is kept open.
    """
    WaveEquationFD(N=200, D=0.1, Mx=100, My=100).solve_and_animate()
    plt.show()


# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()

# Alternative parameter set kept for reference; nothing in the visible part of
# this file reads these values.
# N = 200
# D = 0.25
# Mx = 50
# My = 50

# NOTE(review): `imgs` is assigned but never read in the visible part of this
# file - presumably a leftover from collecting saved frames; confirm before
# removing.
imgs = []
